# -*- coding: utf-8 -*-
"""
@author: XuMing (xuming624@qq.com)
@description: Evaluate a Hugging Face causal LM (default: Qwen-7B) on the CMMLU benchmark with 5-shot prompts.

code from https://github.com/QwenLM/Qwen-7B/blob/main/eval/EVALUATION.md

usage:
wget https://huggingface.co/datasets/haonan-li/cmmlu/resolve/main/cmmlu_v1_0_1.zip
mkdir data/cmmlu
mv cmmlu_v1_0_1.zip data/cmmlu
cd data/cmmlu; unzip cmmlu_v1_0_1.zip
cd ../../
python evaluate_cmmlu.py -d data/cmmlu/
"""
import argparse
import os
from collections import defaultdict
from typing import List

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GenerationConfig
from transformers.trainer_utils import set_seed


def load_models_tokenizer(args):
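    """Load the tokenizer and the fp16 causal LM from the given checkpoint path."""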
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16).eval()
    try:
        model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    except Exception:
        print("GenerationConfig not found, using the model's default generation config.")
    return model, tokenizer


def format_example(line, include_answer=True):
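    """Format one CMMLU row as a Chinese question prompt, optionally appending the gold answer."""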
    example = '问题：' + line['Question']
    for choice in choices:
        example += f'\n{choice}. {line[choice]}'

    if include_answer:
        example += '\n答案：' + line["Answer"] + '\n\n'
    else:
        example += '\n答案：'
    return example


def generate_few_shot_prompt(k, subject, dev_df):
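    """Build a few-shot prefix from the first k dev examples (all of them if k == -1)."""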
    prompt = ''
    if k == -1:
        k = dev_df.shape[0]
    for i in range(k):
        prompt += format_example(
            dev_df.iloc[i, :],
            include_answer=True,
        )
    return prompt


def get_logits(tokenizer, model, inputs: List[str]):
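    """Run one forward pass; despite the name, returns softmax probabilities for the next token."""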
    input_ids = tokenizer(inputs, padding=False)['input_ids']
    input_ids = torch.tensor(input_ids, device=model.device)
    tokens = {'input_ids': input_ids}

    outputs = model(input_ids)['logits']
    logits = outputs[:, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs, {'tokens': tokens}


@torch.no_grad()
def eval_subject(
        model,
        tokenizer,
        subject_name,
        test_df,
        k=5,
        dev_df=None,
        few_shot=False,
        save_result_dir=None,
        **kwargs
):
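    """Evaluate one CMMLU subject: pick the most probable choice letter per question, return accuracy (%)."""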
    result = []
    score = []

    few_shot_prompt = generate_few_shot_prompt(
        k, subject_name, dev_df) if few_shot else ''
    all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
    if args.debug:
        print(f"few_shot_prompt: {few_shot_prompt}")

    for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
        question = format_example(row, include_answer=False)
        full_prompt = few_shot_prompt + question

        output, input_info = get_logits(tokenizer, model, [full_prompt])
        assert output.shape[0] == 1
        logits = output.flatten()

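        # `output` already holds softmax probabilities over the vocabulary; here we
        # renormalize over just the four choice letters. This assumes each letter
        # ("A".."D") encodes to a single token under this tokenizer.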
        softval = torch.nn.functional.softmax(
            torch.tensor(
                [
                    logits[tokenizer("A")['input_ids']],
                    logits[tokenizer("B")['input_ids']],
                    logits[tokenizer("C")['input_ids']],
                    logits[tokenizer("D")['input_ids']],
                ]
            ),
            dim=0,
        )
        if softval.dtype in {torch.bfloat16, torch.float16}:
            softval = softval.to(dtype=torch.float32)
        probs = softval.detach().cpu().numpy()

        for i, choice in enumerate(choices):
            all_probs[f'prob_{choice}'].append(probs[i])
        pred = choices[int(np.argmax(probs))]

        if 'Answer' in row:
            correct = 1 if pred == row['Answer'] else 0
            score.append(correct)
            if args.debug:
                print(f'{question} pred: {pred} ref: {row["Answer"]}')
        result.append(pred)

    if score:
        correct_ratio = 100 * sum(score) / len(score)
        if args.debug:
            print(subject_name, correct_ratio)
    else:
        correct_ratio = 0
    if save_result_dir:
        test_df['model_output'] = result
        for choice in choices:
            test_df[f'prob_{choice}'] = all_probs[f'prob_{choice}']
        if score:
            test_df["correctness"] = score
        os.makedirs(save_result_dir, exist_ok=True)
        test_df.to_csv(os.path.join(
            save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)

    return correct_ratio


def cal_cmmlu(res):
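    """Print per-category average accuracy and the overall average from per-subject results."""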
    print('\n\n\n')
    res = {k.split('-')[-1]: float(v) for k, v in res.items()}
    for k, v in TASK_NAME_MAPPING.items():
        avg_acc = np.mean(list(map(lambda x: res[x], v)))
        print(f"{k} acc: {avg_acc:.2f}")
    avg_all_acc = np.mean(list(res.values()))
    print(f"AVERAGE acc: {avg_all_acc:.2f}")


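# Map each CMMLU subject to its subcategory labels (used for category-level averaging).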
subcategories = {
    "agronomy": ['other'],
    "anatomy": ['biology'],
    "ancient_chinese": ['linguistics', 'china specific'],
    "arts": ['arts'],
    "astronomy": ['physics'],
    "business_ethics": ['business'],
    "chinese_civil_service_exam": ['politics', 'china specific'],
    "chinese_driving_rule": ['other', 'china specific'],
    "chinese_food_culture": ['culture', 'china specific'],
    "chinese_foreign_policy": ['politics', 'china specific'],
    "chinese_history": ['history', 'china specific'],
    "chinese_literature": ['literature', 'china specific'],
    "chinese_teacher_qualification": ['education', 'china specific'],
    "college_actuarial_science": ['math'],
    "college_education": ['education'],
    "college_engineering_hydrology": ['engineering'],
    "college_law": ['law'],
    "college_mathematics": ['math'],
    "college_medical_statistics": ['statistics'],
    "clinical_knowledge": ['other'],
    "college_medicine": ['other'],
    "computer_science": ['computer science'],
    "computer_security": ['other'],
    "conceptual_physics": ['physics'],
    "construction_project_management": ['other', 'china specific'],
    "economics": ['economics'],
    "education": ['education'],
    "elementary_chinese": ['linguistics', 'china specific'],
    "elementary_commonsense": ['other', 'china specific'],
    "elementary_information_and_technology": ['other'],
    "electrical_engineering": ['engineering'],
    "elementary_mathematics": ['math'],
    "ethnology": ['culture', 'china specific'],
    "food_science": ['other'],
    "genetics": ['biology'],
    "global_facts": ['global'],
    "high_school_biology": ['biology'],
    "high_school_chemistry": ['chemistry'],
    "high_school_geography": ['geography'],
    "high_school_mathematics": ['math'],
    "high_school_physics": ['physics'],
    "high_school_politics": ['politics', 'china specific'],
    "human_sexuality": ['other'],
    "international_law": ['law'],
    "journalism": ['sociology'],
    "jurisprudence": ['law'],
    "legal_and_moral_basis": ['other'],
    "logical": ['philosophy'],
    "machine_learning": ['computer science'],
    "management": ['business'],
    "marketing": ['business'],
    "marxist_theory": ['philosophy'],
    "modern_chinese": ['linguistics', 'china specific'],
    "nutrition": ['other'],
    "philosophy": ['philosophy'],
    "professional_accounting": ['business'],
    "professional_law": ['law'],
    "professional_medicine": ['other'],
    "professional_psychology": ['psychology'],
    "public_relations": ['politics'],
    "security_study": ['politics'],
    "sociology": ['culture'],
    "sports_science": ['other'],
    "traditional_chinese_medicine": ['other', 'china specific'],
    "virology": ['biology'],
    "world_history": ['history'],
    "world_religions": ['global'],
}

categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
    "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
    "Social Science": ['linguistics', "business", "politics", "culture", "economics", "geography", "psychology",
                       "education", "sociology"],
    "Other": ["other"],
    "China specific": ["china specific"],
}

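# Invert the two tables above: map each top-level category to its list of subjects.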
TASK_NAME_MAPPING = defaultdict(list)
for k, v in categories.items():
    for subject, subcat in subcategories.items():
        for c in subcat:
            if c in v:
                TASK_NAME_MAPPING[k].append(subject)

choices = ["A", "B", "C", "D"]


def main(args):
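    """Evaluate every CMMLU subject with 5-shot prompts and report aggregated accuracy."""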
    model, tokenizer = load_models_tokenizer(args)

    test_result = {}
    for subject_name in tqdm(subcategories.keys()):
        dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}.csv')
        test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}.csv')
        dev_df = pd.read_csv(dev_file_path)
        test_df = pd.read_csv(test_file_path)

        score = eval_subject(model, tokenizer, subject_name, dev_df=dev_df, test_df=test_df, k=5, few_shot=True,
                             save_result_dir="outs/cmmlu_eval_result")
        test_result[subject_name] = score
    cal_cmmlu(test_result)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
    parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')

    """Provide extra arguments required for tasks."""
    group = parser.add_argument_group(title='Evaluation options')
    group.add_argument('-d', '--eval_data_path', type=str, required=True,
                       help='Path to eval data')
    group.add_argument("--max-seq-len", type=int, default=2048,
                       help='Size of the output generated text.')
    group.add_argument("--debug", action='store_true', default=False,
                       help='Print infos.')

    args = parser.parse_args()
    set_seed(args.seed)

    main(args)
